{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Use matplotlib's inline backend; note that no cell in this notebook draws a figure.\n%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\nCompile PyTorch Object Detection Models\n=======================================\nThis article is an introductory tutorial on deploying PyTorch object\ndetection models with the Relay VM.\n\nTo begin, PyTorch must be installed.\nTorchVision is also required, since we use it as our model zoo.\n\nA quick way is to install both via pip:\n\n.. code-block:: bash\n\n    pip install torch==1.7.0\n    pip install torchvision==0.8.1\n\nor refer to the official site:\nhttps://pytorch.org/get-started/locally/\n\nPyTorch versions should be backward compatible, but each PyTorch version\nmust be paired with its matching TorchVision version.\n\nCurrently, TVM supports PyTorch 1.7 and 1.4. Other versions may\nbe unstable.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import tvm\nfrom tvm import relay\nfrom tvm.runtime.vm import VirtualMachine\nfrom tvm.contrib.download import download_testdata\n\nimport numpy as np\nimport cv2\n\n# PyTorch imports\nimport torch\nimport torchvision"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Load pre-trained maskrcnn from torchvision and do tracing\n---------------------------------------------------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "in_size = 300\n\ninput_shape = (1, 3, in_size, in_size)\n\n\ndef do_trace(model, inp):\n    \"\"\"Trace ``model`` on example input ``inp``; return the traced module in eval mode.\"\"\"\n    model_trace = torch.jit.trace(model, inp)\n    model_trace.eval()\n    return model_trace\n\n\ndef dict_to_tuple(out_dict):\n    \"\"\"Flatten one torchvision detection result dict into a tuple.\n\n    Returns (boxes, scores, labels), followed by masks when the dict has a\n    'masks' entry, in a fixed order so outputs can be indexed positionally.\n    \"\"\"\n    if \"masks\" in out_dict:\n        return out_dict[\"boxes\"], out_dict[\"scores\"], out_dict[\"labels\"], out_dict[\"masks\"]\n    return out_dict[\"boxes\"], out_dict[\"scores\"], out_dict[\"labels\"]\n\n\nclass TraceWrapper(torch.nn.Module):\n    \"\"\"Wrap a detection model so that ``forward`` returns a tuple, for tracing.\n\n    NOTE(review): assumes the wrapped model returns a list of per-image dicts and\n    that the batch holds one image; only the first image's result is kept.\n    \"\"\"\n\n    def __init__(self, model):\n        super().__init__()\n        self.model = model\n\n    def forward(self, inp):\n        out = self.model(inp)\n        return dict_to_tuple(out[0])\n\n\nmodel_func = torchvision.models.detection.maskrcnn_resnet50_fpn\nmodel = TraceWrapper(model_func(pretrained=True))\n\nmodel.eval()\n# Seeded generator so the random example input used for tracing is reproducible.\n# Its shape reuses input_shape, the same shape later declared to Relay.\nrng = np.random.RandomState(0)\ninp = torch.Tensor(rng.uniform(0.0, 250.0, size=input_shape))\n\nwith torch.no_grad():\n    script_module = do_trace(model, inp)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Download a test image and pre-process\n-------------------------------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "img_url = (\n    \"https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/detection/street_small.jpg\"\n)\nimg_path = download_testdata(img_url, \"test_street_small.jpg\", module=\"data\")\n\n# cv2.imread returns None instead of raising when the file cannot be read.\nimg = cv2.imread(img_path)\nif img is None:\n    raise RuntimeError(\"Failed to load test image from {}\".format(img_path))\nimg = cv2.resize(img.astype(\"float32\"), (in_size, in_size))\n# OpenCV loads images in BGR channel order; convert to RGB.\nimg = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n# HWC -> CHW, and scale pixel values from [0, 255] to [0, 1].\nimg = np.transpose(img / 255.0, [2, 0, 1])\n# Add the batch dimension: shape (1, 3, in_size, in_size), matching input_shape.\nimg = np.expand_dims(img, axis=0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Import the graph to Relay\n-------------------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Relay needs the name and static shape of each graph input; the traced module has one.\ninput_name = \"input0\"\nshape_list = [(input_name, input_shape)]\n# Convert the TorchScript module into a Relay module and its parameters.\nmod, params = relay.frontend.from_pytorch(script_module, shape_list)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Compile with Relay VM\n---------------------\nNote: currently only the CPU target is supported. For the x86 target, it is\nhighly recommended to build TVM with Intel MKL and Intel OpenMP to get the\nbest performance, because torchvision R-CNN models contain large dense\noperators.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Add \"-libs=mkl\" to the target for best performance on x86.\n# On x86 machines that support AVX512, the complete target is\n# \"llvm -mcpu=skylake-avx512 -libs=mkl\"\ntarget = \"llvm\"\n\n# NOTE(review): the reason FoldScaleAxis is disabled is not stated here; confirm and document it.\nwith tvm.transform.PassContext(opt_level=3, disabled_pass=[\"FoldScaleAxis\"]):\n    # Compile the Relay module into an executable for the Relay VM.\n    vm_exec = relay.vm.compile(mod, target=target, params=params)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Inference with Relay VM\n-----------------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Execute on CPU; the module was compiled for an llvm (CPU) target.\ndev = tvm.cpu()\nvm = VirtualMachine(vm_exec, dev)\n# Bind the preprocessed image to the input named by input_name in the main function.\nvm.set_input(\"main\", **{input_name: img})\n# Outputs follow the traced tuple order: boxes, scores, labels (and masks, if present).\ntvm_res = vm.run()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Get boxes with score larger than 0.9\n------------------------------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "score_threshold = 0.9\nboxes = tvm_res[0].numpy().tolist()\nscores = tvm_res[1].numpy().tolist()\n# Keep every detection above the threshold instead of stopping at the first low\n# score, so the result does not rely on the outputs being sorted by score.\nvalid_boxes = [box for box, score in zip(boxes, scores) if score > score_threshold]\n\nprint(\"Get {} valid boxes\".format(len(valid_boxes)))"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}